#! /usr/bin/env python3

import shlex

import urlearning.BayesianNetwork as BayesianNetwork
import urlearning.Variable as Variable

class HuginNetStructureReader:
    """Token-based reader that builds a BayesianNetwork from a Hugin net file.

    The file is tokenized with ``shlex`` (non-POSIX mode) and every token is
    lower-cased by ``nextToken``, so keywords, variable names and state values
    are all handled case-insensitively.  The expected layout is a header block
    (``net { ... }``) followed by any number of ``node`` and ``potential``
    blocks.

    Malformed or truncated input raises ``ValueError`` instead of looping
    forever on the end-of-input token.
    """

    def __init__(self):
        # shlex lexer over the file currently being read
        self.stream = None
        # network being populated by the current/last call to read()
        self.network = None
        # the current (most recently read) token; '' means end of input
        self.token = None

    def nextToken(self):
        """Advance to the next token and store it, lower-cased, in self.token.

        At the end of the input shlex returns the empty string, so
        ``self.token`` becomes ``''``.
        """
        self.token = self.stream.get_token().lower()

    def _requireToken(self, expected):
        """Raise ValueError if the input ended while `expected` was awaited."""
        if self.token == '':
            raise ValueError(
                "Invalid Hugin net file: unexpected end of input "
                "while looking for '{}'.".format(expected))

    def _skipUntil(self, target):
        """Advance until the current token equals `target`.

        Raises:
            ValueError: if the input ends before `target` is found.
        """
        while self.token != target:
            self._requireToken(target)
            self.nextToken()

    def _skipTokens(self, count):
        """Advance past `count` tokens."""
        for _ in range(count):
            self.nextToken()

    def read(self, filename):
        """Parse the Hugin net file `filename` and return the network.

        Args:
            filename: path of the file to read.

        Returns:
            The BayesianNetwork populated from the file.

        Raises:
            ValueError: if a top-level block is neither ``node`` nor
                ``potential``, or if the file ends in the middle of a block.
        """
        self.network = BayesianNetwork.BayesianNetwork()

        # the context manager guarantees the file is closed, even on errors
        with open(filename) as netFile:
            self.stream = shlex.shlex(netFile)
            # keep numbers such as "0.25" together as a single token
            self.stream.wordchars += "."
            # everything after a '/' on a line is treated as a comment
            self.stream.commenters += "/"

            self.nextToken()

            # skip the header.... "net { <properties> }"
            # this assumes no properties contain '{'
            self._skipUntil('}')

            # skip the closing brace of the header
            self.nextToken()

            while self.token != '':
                if self.token == 'node':
                    self.parseVariable()
                elif self.token == 'potential':
                    self.parsePotential()
                else:
                    raise ValueError('Invalid Hugin net file.')

        return self.network

    def parseVariable(self):
        """Parse one ``node <name> { ... states = ( ... ); ... }`` block.

        Expects the current token to be ``node``.  Adds the variable and its
        values to the network and leaves the current token on the first token
        after the block's closing brace.

        Raises:
            ValueError: if the input ends before the block is complete.
        """
        # skip 'node'
        self.nextToken()

        variableName = self.token

        self.nextToken()

        # skip everything until we find the word 'states'
        # these could be properties
        self._skipUntil('states')

        # skip 'states', '=', '('
        self._skipTokens(3)

        v = self.network.addVariable(variableName)

        # now, read in the values
        # "Val0" "Val1" ... );
        # in non-POSIX mode shlex keeps the surrounding '"' in the token,
        # so they are stripped here
        while self.token != ')':
            self._requireToken(')')

            # the current token is the value
            v.addValue(self.token.replace('"', ''))

            # move to the next token
            self.nextToken()

        # read past the ')'
        self.nextToken()

        # now read until the end of the block
        self._skipUntil('}')

        # and skip the brace
        self.nextToken()

    def parsePotential(self):
        """Parse one ``potential ( <name> [| <parents>] ) { ... }`` block.

        Expects the current token to be ``potential``.  Dispatches to the root
        or non-root parser depending on whether a ``|`` follows the name.
        """
        # skip 'potential', '(' and get the variable name
        self.nextToken()
        self.nextToken()
        variableName = self.token

        self.nextToken()

        index = self.network.getVariableIndex(variableName)

        # a '|' after the name introduces the parent list
        if self.token == '|':
            self.parseNonrootProbability(index)
        else:
            self.parseRootProbability(index)

    def parseRootProbability(self, index):
        """Read the distribution of the parentless variable at `index`.

        Raises:
            ValueError: if the input ends before ``data`` is found or a
                probability token is not a valid float.
        """
        # skip until 'data'
        self._skipUntil('data')

        # skip 'data', '=' '('
        self._skipTokens(3)

        # update the probability table
        variable = self.network.get(index)
        variable.updateParameterSize()
        parameters = variable.getParameters()[0]

        for k in range(self.network.getArity(index)):
            parameters[k] = float(self.token)
            self.nextToken()

        # and skip the last ")", ";" and "}"
        self._skipTokens(3)

    def parseNonrootProbability(self, index):
        """Read the parent list and distributions of the variable at `index`.

        Expects the current token to be the ``|`` that follows the variable
        name in the potential header.
        """
        v = self.network.get(index)
        self.readParents(index)

        # update the parameters for this variable based on the parents
        v.updateParameterSize()

        # skip the '{', 'data' and '='
        self._skipTokens(3)

        ins = v.getFirstInstantiation()

        # the table is wrapped in |v.parents| + 1 levels of parentheses
        nesting = len(v.getParents()) + 1

        # skip the opening symbols, which are all '('
        self._skipTokens(nesting)

        # now, read each distribution
        for _ in range(v.getParametersSize()):
            self.readParameters(v, ins)

        # skip the closing symbols, which are all ')'
        self._skipTokens(nesting)

        # skip the final ';' and '}'
        self._skipTokens(2)

    def readParents(self, index):
        """Read the parent list of the variable at `index` up to the ')'.

        Each parent name is looked up in the network and added as a parent.
        Leaves the current token on the first token after the ')'.

        Raises:
            ValueError: if the input ends before the closing ')'.
        """
        # skip the '|'
        self.nextToken()

        child = self.network.get(index)
        while self.token != ')':
            self._requireToken(')')

            pIndex = self.network.getVariableIndex(self.token)
            child.addParent(pIndex)

            # and skip to the next token
            self.nextToken()

        # skip the ')'
        self.nextToken()

    def readParameters(self, var, ins):
        """Read one distribution of `var` for the parent instantiation `ins`.

        Stores ``var.getArity()`` probabilities in the row selected by `ins`,
        then advances `ins` to the next parent instantiation and skips the
        parentheses that separate this distribution from the next one.
        """
        # get the appropriate distribution based on the parents
        instantiationIndex = var.getParentIndex(ins)
        parameters = var.getParameters()[instantiationIndex]

        # update the probability values
        for k in range(var.getArity()):
            parameters[k] = float(self.token)
            self.nextToken()

        # get the next instantiation
        changeCount = var.getNextParentInstantiation(ins)

        # skip the parentheses: one ')' and one '(' per nesting level that
        # closes and reopens before the next distribution
        self._skipTokens(2 * (changeCount + 1))

